# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import sys
from hysop import __DEBUG__
from hysop.constants import HYSOP_DEFAULT_TASK_ID, Backend, MemoryOrdering
from hysop.core.checkpoints import CheckpointHandler
from hysop.core.graph.computational_graph import ComputationalGraph
from hysop.tools.contexts import Timer
from hysop.tools.decorators import debug, profile
from hysop.tools.parameters import MPIParams
from hysop.tools.string_utils import vprint, vprint_banner
from hysop.tools.htypes import check_instance, first_not_None, to_list, to_tuple
class Problem(ComputationalGraph):
    """Root computational graph of a simulation.

    A Problem is a ComputationalGraph that enforces MPI parameters,
    records the MPI task ids of inserted operators (to decide whether
    inter-task redistribution operators have to be searched for), and
    drives the build / solve / finalize life cycle.
    """

    def __new__(
        cls, name=None, method=None, mpi_params=None, check_unique_clenv=True, **kwds
    ):
        return super().__new__(cls, **kwds)

    def __init__(
        self, name=None, method=None, mpi_params=None, check_unique_clenv=True, **kwds
    ):
        """Initialize a Problem.

        Parameters
        ----------
        name: str, optional
            Name of this computational graph.
        method: dict, optional
            Default method table forwarded to the graph.
        mpi_params: MPIParams, optional
            MPI parameters; a default MPIParams() is enforced when None.
        check_unique_clenv: bool, optional
            If True (default), check after discretization that all
            operators share a single OpenCL environment.
        kwds: dict
            Extra keyword arguments forwarded to ComputationalGraph.
        """
        mpi_params = first_not_None(
            mpi_params, MPIParams()
        )  # enforce mpi params for problems
        super().__init__(name=name, method=method, mpi_params=mpi_params, **kwds)
        self._do_check_unique_clenv = check_unique_clenv
        # None until the first insert() decides whether inter-task
        # operators must be searched during initialization.
        self.search_intertasks_ops = None
        # Task ids collected from every inserted operator.
        self.ops_tasks = []

    @debug
    def insert(self, *ops):
        """Insert computational nodes into this problem.

        Records the MPI task id of each inserted operator; the
        inter-task operator search is disabled only when every recorded
        task id matches the problem's own task id (single-task problem).

        Returns self, to allow call chaining.
        """
        for node in ops:
            if hasattr(node, "mpi_params") and node.mpi_params:
                self.ops_tasks.append(node.mpi_params.task_id)
            if hasattr(node, "impl_kwds") and "mpi_params" in node.impl_kwds:
                self.ops_tasks.append(node.impl_kwds["mpi_params"].task_id)
        given_ops_have_tasks = True
        pb_task_id = (
            HYSOP_DEFAULT_TASK_ID
            if self.mpi_params is None
            else self.mpi_params.task_id
        )
        if len(set(self.ops_tasks)) == 1 and self.ops_tasks[0] == pb_task_id:
            # Intertask is not needed this is a single task-problem
            given_ops_have_tasks = False
        self.search_intertasks_ops = given_ops_have_tasks
        self.push_nodes(*ops)
        return self

    def _averaged_elapsed_time(self, interval):
        """Return (comm_size, mean elapsed time of *interval* over ranks).

        Uses the domain parent communicator when the domain spans
        several tasks, the problem communicator otherwise.
        """
        comm = self.mpi_params.comm
        if (self.domain is not None) and self.domain.has_tasks:
            comm = self.domain.parent_comm
        size = comm.Get_size()
        avg_time = comm.allreduce(interval) / size
        return size, avg_time

    @debug
    def build(
        self,
        args=None,
        allow_subbuffers=False,
        outputs_are_inputs=True,
        search_intertasks_ops=None,
    ):
        """Build the problem (initialize, discretize, allocate, setup).

        Timing of the whole build is reported, averaged over MPI ranks.
        When *args* requests an early stop (stop_at_* flags), a banner
        is printed and the process exits with status 0.
        """
        with Timer() as tm:
            msg = self.build_problem(
                args=args,
                allow_subbuffers=allow_subbuffers,
                outputs_are_inputs=outputs_are_inputs,
                search_intertasks_ops=search_intertasks_ops,
            )
            if msg:
                msg = f" Problem {msg} achieved, exiting ! "
                vprint_banner(msg, at_border=2)
                sys.exit(0)

        size, avg_time = self._averaged_elapsed_time(tm.interval)
        msg = " Problem building took {} ({}s)"
        if size > 1:
            msg += f", averaged over {size} ranks. "
        msg = msg.format(datetime.timedelta(seconds=round(avg_time)), avg_time)
        vprint_banner(msg, spacing=True, at_border=2)

        if (args is not None) and args.stop_at_build:
            msg = " Problem has been built, exiting. "
            vprint_banner(msg, at_border=2)
            sys.exit(0)

    def build_problem(
        self,
        args,
        allow_subbuffers,
        outputs_are_inputs=True,
        search_intertasks_ops=None,
    ):
        """Run the successive build stages of the problem.

        Returns the name of the stage at which *args* asked to stop,
        or None when the problem has been fully set up.
        """
        if (args is not None) and args.stop_at_initialization:
            return "initialization"
        vprint("\nInitializing problem... " + str(self.name))
        # explicit argument takes precedence over what insert() recorded
        search_intertasks = search_intertasks_ops
        if search_intertasks is None:
            search_intertasks = self.search_intertasks_ops
        self.initialize(
            outputs_are_inputs=outputs_are_inputs,
            topgraph_method=None,
            is_root=True,
            search_intertasks_ops=search_intertasks,
        )
        if (args is not None) and args.stop_at_discretization:
            return "discretization"
        vprint("\nDiscretizing problem... " + str(self.name))
        # discretize nested sub-problems before this problem itself
        for node in [_ for _ in self.nodes if isinstance(_, Problem)]:
            node.discretize()
        self.discretize()
        if (args is not None) and args.stop_at_work_properties:
            return "work properties retrieval"
        vprint("\nGetting work properties... " + str(self.name))
        work = self.get_work_properties()
        if (args is not None) and args.stop_at_work_allocation:
            return "work allocation"
        vprint("\nAllocating work... " + str(self.name))
        work.allocate(allow_subbuffers=allow_subbuffers)
        if (args is not None) and args.stop_at_setup:
            return "setup"
        vprint("\nSetting up problem..." + str(self.name))
        self.setup(work)

    def discretize(self):
        """Discretize the graph, then optionally check OpenCL env unicity."""
        super().discretize()
        if self._do_check_unique_clenv:
            self.check_unique_clenv()

    def check_unique_clenv(self):
        """Check that all OpenCL topologies share one cl_env.

        Raises RuntimeError on the first mismatch found. Override this
        method if using multiple OpenCL environments is intended.
        """
        cl_env, first_op = None, None
        for op in self.nodes:
            for topo in set(op.input_fields.values()).union(
                set(op.output_fields.values())
            ):
                if topo is not None and (topo.backend.kind == Backend.OPENCL):
                    if cl_env is None:
                        # remember the reference environment and its operator
                        first_op = op
                        cl_env = topo.backend.cl_env
                    elif topo.backend.cl_env is not cl_env:
                        msg = ""
                        msg += "\nOpenCl environment mismatch between operator {} and operator {}."
                        msg = msg.format(first_op.name, op.name)
                        msg += f"\n{cl_env}"
                        msg += "\n and"
                        msg += f"\n{topo.backend.cl_env}"
                        msg += "\n If this is required, override check_unique_clenv()."
                        raise RuntimeError(msg)

    def initialize_field(self, field, mpi_params=None, **kwds):
        """Initialize a field on all its input and output topologies.

        Tensor fields are initialized first so that their contained
        scalar fields are not initialized twice. Raises RuntimeError
        when the field could not be initialized anywhere.
        """
        initialized = set()

        # NOTE: single leading underscore on purpose — a dunder-prefixed
        # local would be silently name-mangled inside the class body.
        def _iterate_nodes(nodes):
            # depth-first iteration over nodes, descending into sub-problems
            for node in nodes:
                if isinstance(node, Problem):
                    yield from _iterate_nodes(node.nodes)
                yield node

        # give priority to tensor field initialization
        for op_fields in (
            self.input_discrete_tensor_fields,
            self.output_discrete_tensor_fields,
        ) + tuple(
            _
            for op in _iterate_nodes(self.nodes)
            for _ in (
                op.input_discrete_tensor_fields,
                op.output_discrete_tensor_fields,
                op.input_discrete_fields,
                op.output_discrete_fields,
            )
        ):
            if field in op_fields:
                dfield = op_fields[field]
                if all((df in initialized) for df in dfield.discrete_fields()):
                    # all contained scalar fields were already initialized
                    continue
                elif mpi_params and not all(
                    mpi_params.task_id == df.topology.task_id
                    for df in dfield.discrete_fields()
                ):
                    # Topology task does not matches given mpi_params task
                    continue
                else:
                    # initialize only the components not yet seen
                    components = ()
                    for component, scalar_dfield in dfield.nd_iter():
                        if scalar_dfield._dfield not in initialized:
                            components += (component,)
                    dfield.initialize(components=components, **kwds)
                    initialized.update(dfield.discrete_fields())
        if not initialized:
            msg = f"FATAL ERROR: Could not initialize field {field.name}."
            raise RuntimeError(msg)

    @debug
    @profile
    def solve(
        self,
        simu,
        dry_run=False,
        dbg=None,
        report_freq=10,
        plot_freq=10,
        checkpoint_handler=None,
        **kwds,
    ):
        """Advance the simulation until completion.

        Parameters
        ----------
        simu:
            Simulation object driving the time loop.
        dry_run: bool, optional
            If True, skip the whole simulation.
        dbg: optional
            Debugger hook forwarded to apply()/advance().
        report_freq: int, optional
            Profiler report frequency, in iterations (0 disables).
        plot_freq: int, optional
            Plotting frequency forwarded to simu.advance().
        checkpoint_handler: CheckpointHandler, optional
            Handler used to load and periodically dump checkpoints.
        kwds: dict
            Extra keyword arguments forwarded to apply().
        """
        if dry_run:
            vprint()
            vprint_banner("** Dry-run requested, skipping simulation. **")
            return
        simu.initialize()
        check_instance(checkpoint_handler, CheckpointHandler, allow_none=True)
        if checkpoint_handler is not None:
            checkpoint_handler.create_checkpoint_template(self, simu)
            checkpoint_handler.load_checkpoint(self, simu)
        vprint("\nSolving problem...")
        with Timer() as tm:
            while not simu.is_over:
                vprint()
                simu.print_state()
                self.apply(simulation=simu, dbg=dbg, **kwds)
                should_dump_checkpoint = (
                    checkpoint_handler is not None
                ) and checkpoint_handler.should_dump(
                    simu
                )  # determined before simu advance
                simu.advance(dbg=dbg, plot_freq=plot_freq)
                if should_dump_checkpoint:
                    checkpoint_handler.save_checkpoint(self, simu)
                if report_freq and (simu.current_iteration % report_freq) == 0:
                    self.profiler_report()

        size, avg_time = self._averaged_elapsed_time(tm.interval)
        msg = " Simulation took {} ({}s)"
        if size > 1:
            msg += f", averaged over {size} ranks. "
        msg += "\n for {} iterations ({}s per iteration) "
        msg = msg.format(
            datetime.timedelta(seconds=round(avg_time)),
            avg_time,
            max(simu.current_iteration + 1, 1),
            avg_time / max(simu.current_iteration + 1, 1),
        )
        vprint_banner(msg, spacing=True, at_border=2)

        simu.finalize()
        if checkpoint_handler is not None:
            checkpoint_handler.finalize(self.mpi_params)
        self.final_report()
        if dbg is not None:
            dbg("final iteration", nostack=True)

    def final_report(self):
        """Print the profiler report, plus the per-task report on root."""
        self.profiler_report()
        if self.is_root or __DEBUG__ or self.__FORCE_REPORTS__:
            vprint(self.task_profiler_report())

    @debug
    def finalize(self):
        """Finalize the underlying computational graph."""
        vprint("Finalizing problem...")
        super().finalize()